Getting started with classification using scikit-learn

We'll import the tools we need, and then use cross-validation to test a classifier.

First, we'll load up some data, then we'll set up a classifier.


In [5]:
import numpy as np
def load_dataset(path, dataset_name):
    '''
    data, labels = load_dataset(path, dataset_name)

    Load a tab-separated dataset from ``<path>/<dataset_name>.tsv``.

    Every non-blank line holds the numeric feature values followed by the
    label in the last column.

    Parameters
    ----------
    path : str
        Directory containing the dataset file.
    dataset_name : str
        File name without the ``.tsv`` extension.

    Returns
    -------
    data : numpy ndarray
        One row per non-blank line, built from all columns but the last,
        each parsed with ``float``.
    labels : numpy ndarray of str
        Last column of each non-blank line, in file order.

    Raises
    ------
    ValueError
        If a feature column cannot be parsed as a float.
    '''
    data = []
    labels = []
    with open('{}/{}.tsv'.format(path, dataset_name)) as ifile:
        for line in ifile:
            line = line.strip()
            # Skip blank lines (e.g. a trailing newline at end of file);
            # they would otherwise add an empty row and an empty label,
            # producing a ragged feature array.
            if not line:
                continue
            tokens = line.split('\t')
            data.append([float(tk) for tk in tokens[:-1]])
            labels.append(tokens[-1])
    return np.array(data), np.array(labels)

# Read the feature matrix and labels from book/ch02/data/seeds.tsv.
features, labels = load_dataset(path="book/ch02/data", dataset_name="seeds")

In [6]:
from sklearn.neighbors import KNeighborsClassifier
# Use a single nearest neighbour to classify each sample.
classifier = KNeighborsClassifier(n_neighbors=1)  # the default n_neighbors is 5

Now we'll run the cross-validation using scikit-learn's tools.


In [7]:
# `sklearn.cross_validation` was removed from scikit-learn; KFold now lives in
# `sklearn.model_selection`, takes `n_splits`, and yields index pairs via `.split()`.
from sklearn.model_selection import KFold

# Shuffle before splitting, with a fixed seed so the folds (and therefore the
# reported accuracy) are the same on every Restart & Run All.
kf = KFold(n_splits=5, shuffle=True, random_state=0)
means = []  # accuracy on the held-out fold, one entry per fold
for training, testing in kf.split(features):
    # Fit on the training indices only, then predict the held-out samples.
    classifier.fit(features[training], labels[training])
    prediction = classifier.predict(features[testing])

    # Fraction of held-out samples whose predicted label matches the true one.
    curmean = np.mean(prediction == labels[testing])
    means.append(curmean)

print("Mean accuracy: {:.1%}".format(np.mean(means)))


Mean accuracy: 91.4%